{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# CatBoost Algorithms"
      ],
      "metadata": {}
    },
    {
      "cell_type": "markdown",
      "source": [
        "CatBoost (categorical boosting) is an algorithm for gradient boosting on decision trees. Is a machine learning technique created by Yandex; therefore, it outperforms many existing boosting such as XGBoost and Light GBM. "
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import pandas as pd\n",
        "\n",
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "\n",
        "# fix_yahoo_finance is used to fetch data \n",
        "import fix_yahoo_finance as yf\n",
        "yf.pdr_override()"
      ],
      "outputs": [],
      "execution_count": 1,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# input\n",
        "symbol = 'AMD'\n",
        "start = '2007-01-01'\n",
        "end = '2018-11-16'\n",
        "\n",
        "# Read data \n",
        "dataset = yf.download(symbol,start,end)\n",
        "\n",
        "# View Columns\n",
        "dataset.head()"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[*********************100%***********************]  1 of 1 downloaded\n"
          ]
        },
        {
          "output_type": "execute_result",
          "execution_count": 2,
          "data": {
            "text/plain": [
              "                 Open       High        Low      Close  Adj Close    Volume\n",
              "Date                                                                       \n",
              "2007-01-03  20.080000  20.400000  19.350000  19.520000  19.520000  28350300\n",
              "2007-01-04  19.660000  19.860001  19.320000  19.790001  19.790001  23652500\n",
              "2007-01-05  19.540001  19.910000  19.540001  19.709999  19.709999  15902400\n",
              "2007-01-08  19.709999  19.860001  19.370001  19.469999  19.469999  15814800\n",
              "2007-01-09  19.450001  19.709999  19.370001  19.650000  19.650000  14494200"
            ],
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Open</th>\n",
              "      <th>High</th>\n",
              "      <th>Low</th>\n",
              "      <th>Close</th>\n",
              "      <th>Adj Close</th>\n",
              "      <th>Volume</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Date</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>2007-01-03</th>\n",
              "      <td>20.080000</td>\n",
              "      <td>20.400000</td>\n",
              "      <td>19.350000</td>\n",
              "      <td>19.520000</td>\n",
              "      <td>19.520000</td>\n",
              "      <td>28350300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2007-01-04</th>\n",
              "      <td>19.660000</td>\n",
              "      <td>19.860001</td>\n",
              "      <td>19.320000</td>\n",
              "      <td>19.790001</td>\n",
              "      <td>19.790001</td>\n",
              "      <td>23652500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2007-01-05</th>\n",
              "      <td>19.540001</td>\n",
              "      <td>19.910000</td>\n",
              "      <td>19.540001</td>\n",
              "      <td>19.709999</td>\n",
              "      <td>19.709999</td>\n",
              "      <td>15902400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2007-01-08</th>\n",
              "      <td>19.709999</td>\n",
              "      <td>19.860001</td>\n",
              "      <td>19.370001</td>\n",
              "      <td>19.469999</td>\n",
              "      <td>19.469999</td>\n",
              "      <td>15814800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2007-01-09</th>\n",
              "      <td>19.450001</td>\n",
              "      <td>19.709999</td>\n",
              "      <td>19.370001</td>\n",
              "      <td>19.650000</td>\n",
              "      <td>19.650000</td>\n",
              "      <td>14494200</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ]
          },
          "metadata": {}
        }
      ],
      "execution_count": 2,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "dataset['Increase_Decrease'] = np.where(dataset['Volume'].shift(-1) > dataset['Volume'],1,0)\n",
        "dataset['Buy_Sell_on_Open'] = np.where(dataset['Open'].shift(-1) > dataset['Open'],1,0)\n",
        "dataset['Buy_Sell'] = np.where(dataset['Adj Close'].shift(-1) > dataset['Adj Close'],1,0)\n",
        "dataset['Returns'] = dataset['Adj Close'].pct_change()\n",
        "dataset = dataset.dropna()\n",
        "dataset.head()"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "execution_count": 3,
          "data": {
            "text/plain": [
              "                 Open       High        Low      Close  Adj Close    Volume  \\\n",
              "Date                                                                          \n",
              "2007-01-04  19.660000  19.860001  19.320000  19.790001  19.790001  23652500   \n",
              "2007-01-05  19.540001  19.910000  19.540001  19.709999  19.709999  15902400   \n",
              "2007-01-08  19.709999  19.860001  19.370001  19.469999  19.469999  15814800   \n",
              "2007-01-09  19.450001  19.709999  19.370001  19.650000  19.650000  14494200   \n",
              "2007-01-10  19.639999  20.020000  19.500000  20.010000  20.010000  19783200   \n",
              "\n",
              "            Increase_Decrease  Buy_Sell_on_Open  Buy_Sell   Returns  \n",
              "Date                                                                 \n",
              "2007-01-04                  0                 0         0  0.013832  \n",
              "2007-01-05                  0                 1         0 -0.004043  \n",
              "2007-01-08                  0                 0         1 -0.012177  \n",
              "2007-01-09                  1                 1         1  0.009245  \n",
              "2007-01-10                  1                 1         1  0.018321  "
            ],
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Open</th>\n",
              "      <th>High</th>\n",
              "      <th>Low</th>\n",
              "      <th>Close</th>\n",
              "      <th>Adj Close</th>\n",
              "      <th>Volume</th>\n",
              "      <th>Increase_Decrease</th>\n",
              "      <th>Buy_Sell_on_Open</th>\n",
              "      <th>Buy_Sell</th>\n",
              "      <th>Returns</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Date</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>2007-01-04</th>\n",
              "      <td>19.660000</td>\n",
              "      <td>19.860001</td>\n",
              "      <td>19.320000</td>\n",
              "      <td>19.790001</td>\n",
              "      <td>19.790001</td>\n",
              "      <td>23652500</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0.013832</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2007-01-05</th>\n",
              "      <td>19.540001</td>\n",
              "      <td>19.910000</td>\n",
              "      <td>19.540001</td>\n",
              "      <td>19.709999</td>\n",
              "      <td>19.709999</td>\n",
              "      <td>15902400</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>-0.004043</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2007-01-08</th>\n",
              "      <td>19.709999</td>\n",
              "      <td>19.860001</td>\n",
              "      <td>19.370001</td>\n",
              "      <td>19.469999</td>\n",
              "      <td>19.469999</td>\n",
              "      <td>15814800</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>-0.012177</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2007-01-09</th>\n",
              "      <td>19.450001</td>\n",
              "      <td>19.709999</td>\n",
              "      <td>19.370001</td>\n",
              "      <td>19.650000</td>\n",
              "      <td>19.650000</td>\n",
              "      <td>14494200</td>\n",
              "      <td>1</td>\n",
              "      <td>1</td>\n",
              "      <td>1</td>\n",
              "      <td>0.009245</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2007-01-10</th>\n",
              "      <td>19.639999</td>\n",
              "      <td>20.020000</td>\n",
              "      <td>19.500000</td>\n",
              "      <td>20.010000</td>\n",
              "      <td>20.010000</td>\n",
              "      <td>19783200</td>\n",
              "      <td>1</td>\n",
              "      <td>1</td>\n",
              "      <td>1</td>\n",
              "      <td>0.018321</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ]
          },
          "metadata": {}
        }
      ],
      "execution_count": 3,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "dataset.shape"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "execution_count": 4,
          "data": {
            "text/plain": [
              "(2991, 10)"
            ]
          },
          "metadata": {}
        }
      ],
      "execution_count": 4,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "X = dataset[['Open', 'High', 'Low', 'Volume']].values\n",
        "y = dataset['Buy_Sell'].values"
      ],
      "outputs": [],
      "execution_count": 5,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import catboost as cb\n",
        "from catboost import CatBoostClassifier\n",
        "from sklearn import metrics\n",
        "from sklearn.model_selection import train_test_split\n",
        "# from sklearn.model_selection import GridSearchCV"
      ],
      "outputs": [],
      "execution_count": 6,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0)"
      ],
      "outputs": [],
      "execution_count": 7,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "cb = CatBoostClassifier(iterations=5, learning_rate=0.1)\n",
        "cb.fit(X_train, y_train)"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0:\tlearn: 0.6923219\ttotal: 150ms\tremaining: 599ms\n",
            "1:\tlearn: 0.6914714\ttotal: 241ms\tremaining: 362ms\n",
            "2:\tlearn: 0.6904868\ttotal: 339ms\tremaining: 226ms\n",
            "3:\tlearn: 0.6895988\ttotal: 433ms\tremaining: 108ms\n",
            "4:\tlearn: 0.6889079\ttotal: 520ms\tremaining: 0us\n"
          ]
        },
        {
          "output_type": "execute_result",
          "execution_count": 8,
          "data": {
            "text/plain": [
              "<catboost.core.CatBoostClassifier at 0x24d1c4560f0>"
            ]
          },
          "metadata": {}
        }
      ],
      "execution_count": 8,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print('Model is fitted: ' + str(cb.is_fitted()))\n",
        "print('Model params:')\n",
        "print(cb.get_params())"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model is fitted: True\n",
            "Model params:\n",
            "{'loss_function': 'Logloss', 'learning_rate': 0.1, 'iterations': 5}\n"
          ]
        }
      ],
      "execution_count": 9,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "y_pred = cb.predict(X_test)"
      ],
      "outputs": [],
      "execution_count": 10,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "y_pred"
      ],
      "outputs": [
        {
          "output_type": "execute_result",
          "execution_count": 11,
          "data": {
            "text/plain": [
              "array([ 1.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,\n",
              "        0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
              "        0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
              "        0.,  1.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,\n",
              "        1.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
              "        1.,  1.,  1.,  0.,  1.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  1.,\n",
              "        0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,\n",
              "        1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,\n",
              "        0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  1.,\n",
              "        0.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
              "        0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,\n",
              "        1.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  1.,  1.,  1.,  0.,  1.,\n",
              "        0.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,\n",
              "        0.,  1.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,\n",
              "        1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,\n",
              "        0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,\n",
              "        0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  1.,  0.,  1.,  0.,\n",
              "        1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,\n",
              "        0.,  0.,  0.,  1.,  1.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,\n",
              "        1.,  0.,  0.,  0.,  1.,  1.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,\n",
              "        1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
              "        0.,  1.,  0.,  0.,  1.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,\n",
              "        1.,  1.,  0.,  0.,  1.,  1.,  1.,  0.,  1.,  1.,  0.,  1.,  0.,\n",
              "        1.,  0.,  1.,  1.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  1.,\n",
              "        0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,\n",
              "        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,\n",
              "        0.,  1.,  0.,  1.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,\n",
              "        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,\n",
              "        1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.,\n",
              "        0.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,\n",
              "        1.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  1.,\n",
              "        0.,  0.,  0.,  1.,  1.,  0.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,\n",
              "        0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,\n",
              "        1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,\n",
              "        0.,  1.,  1.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,\n",
              "        0.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  1.,  0.,\n",
              "        1.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,\n",
              "        0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,\n",
              "        0.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  0.,\n",
              "        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
              "        0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,\n",
              "        1.,  0.,  1.,  0.,  0.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  1.,\n",
              "        0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,\n",
              "        0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,  1.,  1.,  0.,\n",
              "        0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,\n",
              "        0.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.])"
            ]
          },
          "metadata": {}
        }
      ],
      "execution_count": 11,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    },
    {
      "cell_type": "code",
      "source": [
        "print('CatBoost Score:', cb.score(X_test, y_test))"
      ],
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "CatBoost Score: 0.544240400668\n"
          ]
        }
      ],
      "execution_count": 12,
      "metadata": {
        "collapsed": false,
        "outputHidden": false,
        "inputHidden": false
      }
    }
  ],
  "metadata": {
    "kernel_info": {
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.5.5",
      "file_extension": ".py",
      "nbconvert_exporter": "python",
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "pygments_lexer": "ipython3",
      "mimetype": "text/x-python"
    },
    "kernelspec": {
      "name": "python3",
      "language": "python",
      "display_name": "Python 3"
    },
    "nteract": {
      "version": "0.15.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}